{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3d9f4abf",
   "metadata": {},
   "source": [
    "# Implementing Transformers\n",
    "\n",
    "This notebook will walk you through the internals of implementing self attention and transformer networks.  As with recurrent networks (and unlike convolutions), there is actually relatively little that is fundamentally new in their implementation, as it all largely involves an application of existing primitives you will have already implemented in your autodiff framework.  However, there is indeed one aspect of an efficient implementation that requires a slight generalization of an item we have discussed already: a _batch_ version of matrix multiplication.  This is required for both the minibatch version of attention as well as the common \"multihead\" version.  We will also briefly discuss some approaches to making Transformers more efficient."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e150202b",
   "metadata": {},
   "source": [
    "## Implementing self-attention\n",
    "\n",
    "Let's begin with a simple implementation of self-attention.  This essentially just implements the basic equation\n",
    "\n",
    "\\begin{equation}\n",
    "Y = \\mathrm{softmax}\\left(\\frac{KQ^T}{\\sqrt{d}}\\right)V\n",
    "\\end{equation}\n",
    "\n",
    "By convention, however, it's typical to implement self attention in terms of the actual inputs $X$ rather than the $K$, $Q$, and $V$ values themselves (i.e., instead of having the linear layer separately).  It's also common to have an output weight as well (even though this could in theory be folded into the $W_{KQV}$ terms), which applies an additional linear layer to the output of the the entire operation.  I.e., the full operation is given by\n",
    "\\begin{equation}\n",
    "Y = \\left(\\mathrm{softmax}\\left(\\frac{X W_K W_Q^T X^T}{\\sqrt{d}}\\right)X W_V \\right) W_o.\n",
    "\\end{equation}\n",
    "It's possible to also incorporate bias terms into each of these projections, though we won't bother with this, as it is less common for everything but the output weight, and then just largely adds complexity.\n",
    "\n",
    "Let's see what this implementation looks like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a35fb17d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "67a9d7ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "def softmax(Z):\n",
    "    Z = np.exp(Z - Z.max(axis=-1, keepdims=True))\n",
    "    return Z / Z.sum(axis=-1, keepdims=True)\n",
    "    \n",
    "def self_attention(X, mask, W_KQV, W_out):\n",
    "    K,Q,V = np.split(X@W_KQV, 3, axis=-1)\n",
    "    attn = softmax(K@Q.swapaxes(-1,-2) / np.sqrt(X.shape[-1]) + mask)\n",
    "    return attn@V@W_out, attn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbe332ed",
   "metadata": {},
   "source": [
    "We can compare this to PyTorch's self-attention implementation, the `nn.MultiheadAttention` layer (we'll cover what we mean by \"multi-head\" shortly).  Note that by default (mainly just to be similar to the RNN implementation and other sequence models, the `nn.MultiheadAttention` layer _also_ by default takes inputs in $(T,N,d)$ form (i.e, the batch dimension second.  But unlike for RNNs, this ordering doesn't make much sense for self-attention and Transformers: we will be computing the operation \"in parallel\" over all times points, instead of as a sequential model like for RNNs.  So we'll use the `batch_first=True` flag to make this a more natural dimension ordering for the inputs.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3cc24289",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., -inf, -inf, -inf, -inf],\n",
       "        [0., 0., -inf, -inf, -inf],\n",
       "        [0., 0., 0., -inf, -inf],\n",
       "        [0., 0., 0., 0., -inf],\n",
       "        [0., 0., 0., 0., 0.]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "T = 5\n",
    "M = torch.triu(-float(\"inf\")*torch.ones(T,T),1)"
   ]
  },
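  {
   "cell_type": "markdown",
   "id": "a1c0de01",
   "metadata": {},
   "source": [
    "The following cell is an illustrative sketch (not part of the comparison with PyTorch) of what this mask does once it is added to the attention scores: entries that receive $-\\infty$ become exactly zero after the softmax, so row $i$ of the attention matrix only places weight on columns $j \\le i$.  The names `softmax_demo`, `mask_demo`, `scores_demo`, and `A_demo` are hypothetical, chosen only to avoid clobbering variables used elsewhere in the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1c0de02",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def softmax_demo(Z):\n",
    "    # numerically stable softmax over the last axis\n",
    "    Z = np.exp(Z - Z.max(axis=-1, keepdims=True))\n",
    "    return Z / Z.sum(axis=-1, keepdims=True)\n",
    "\n",
    "T_demo = 5\n",
    "# -inf strictly above the diagonal, 0 elsewhere (same pattern as M above)\n",
    "mask_demo = np.triu(np.full((T_demo, T_demo), -np.inf), 1)\n",
    "scores_demo = np.random.randn(T_demo, T_demo)\n",
    "A_demo = softmax_demo(scores_demo + mask_demo)\n",
    "\n",
    "# masked entries get zero weight, and every row still sums to one\n",
    "print(np.allclose(np.triu(A_demo, 1), 0))\n",
    "print(np.allclose(A_demo.sum(axis=-1), 1))"
   ]
  },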
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "30733bf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "T, d = 100, 64\n",
    "attn = nn.MultiheadAttention(d, 1, bias=False, batch_first=True)\n",
    "M = torch.triu(-float(\"inf\")*torch.ones(T,T),1)\n",
    "X = torch.randn(1,T,d)\n",
    "Y_, A_ = attn(X,X,X, attn_mask=M)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2779de31",
   "metadata": {},
   "outputs": [],
   "source": [
    "Y, A = self_attention(X[0].numpy(), M.numpy(), \n",
    "                      attn.in_proj_weight.detach().numpy().T,\n",
    "                      attn.out_proj.weight.detach().numpy().T)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "f28d3c52",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.8741974e-07\n",
      "1.3277154e-06\n"
     ]
    }
   ],
   "source": [
    "print(np.linalg.norm(A - A_[0].detach().numpy()))\n",
    "print(np.linalg.norm(Y - Y_[0].detach().numpy()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e22a711c",
   "metadata": {},
   "source": [
    "## Minibatching with batch matrix multiply\n",
    "\n",
    "Once we move from single example to minibatches, there is one additional subtlety that comes into play for self-attenion.  Recall that for _each_ sample in the minibatch, we will have to compute a matrix product, e.g., the $KQ^T$ term.  If we need to process examples in a minibatch, we will need to perform this matrix multiplication correspondingly for each sample.  This is an operation known as a batch matrix multiply."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05a54cdb",
   "metadata": {},
   "source": [
    "It may seem as though nothing is new here.  True, for an MLP it was possible to perform the entire batch equation as a single matrix multiplication, but didn't we similarly need to batch matrix multiplications for convolutional networks (after the im2col function)?  Or for RNNs?\n",
    "\n",
    "The answer is actually that no, previous to this we haven't needed the true batch matrix multiplication fuctionality.  The situations we had before involved the multiplication of a \"batched\" tensor by a _single_ weight matrix.  I.e., in a ConvNet, we had something like\n",
    "\\begin{equation}\n",
    "y = \\mathrm{im2col}(x) W\n",
    "\\end{equation}\n",
    "or in the batched setting\n",
    "\\begin{equation}\n",
    "y^{(i)} = \\mathrm{im2col}\\left(x^{(i)}\\right) W.\n",
    "\\end{equation}\n",
    "\n",
    "But this operation can be accomplished with \"normal\" matrix multiplication by just stacking the multiple samples into the matrix on the left\n",
    "\\begin{equation}\n",
    "\\begin{bmatrix}\n",
    "y^{(1)} \\\\ y^{(2)} \\\\ \\vdots \\\\ y^{(N)}\n",
    "\\end{bmatrix}\n",
    "= \n",
    "\\begin{bmatrix}\n",
    "\\mathrm{im2col}\\left(x^{(1)}\\right) \\\\\n",
    "\\mathrm{im2col}\\left(x^{(2)}\\right) \\\\\n",
    "\\vdots \\\\\n",
    "\\mathrm{im2col}\\left(x^{(N)}\\right) \\\\\n",
    "\\end{bmatrix}\n",
    "W.\n",
    "\\end{equation}\n",
    "This operation is just a normal matrix multiplication, so can be implemented e.g., using your framework so far, where matrix multiplication always operates on 2 dimensional NDArrays.\n",
    "\n",
    "Fortunately, numpy's `@` operator _already_ performs batch matrix multiplication for the case of multiple arrays of (the same) dimension more than 2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "f3a70fa6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10, 3, 5, 3)"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# illustration of batch matmul\n",
    "B = np.random.randn(10,3,5,4)\n",
    "C = np.random.randn(10,3,4,3)\n",
    "(B@C).shape"
   ]
  },
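  {
   "cell_type": "markdown",
   "id": "a1c0de03",
   "metadata": {},
   "source": [
    "As a sanity check on this behavior, the sketch below compares `@` on 4D arrays against an explicit Python loop that performs one ordinary 2D matrix multiplication per pair of leading indices.  The names `B_demo`, `C_demo`, and `looped` are hypothetical and used only for this illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1c0de04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# demo arrays with the same shapes as the illustration above\n",
    "B_demo = np.random.randn(10, 3, 5, 4)\n",
    "C_demo = np.random.randn(10, 3, 4, 3)\n",
    "\n",
    "# explicit loop: one 2D matmul for each pair of leading (batch) indices\n",
    "looped = np.empty((10, 3, 5, 3))\n",
    "for i in range(10):\n",
    "    for j in range(3):\n",
    "        looped[i, j] = B_demo[i, j] @ C_demo[i, j]\n",
    "\n",
    "# the batched `@` should agree with the loop up to floating point error\n",
    "print(np.allclose(B_demo @ C_demo, looped))"
   ]
  },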
  {
   "cell_type": "markdown",
   "id": "c28c527a",
   "metadata": {},
   "source": [
    "Let's see how this works with our self attention layer.  In fact, because of the judicious usage of `axis=-1` and similar terms, our layer works _exactly_ the same as it did before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "fa24a2cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 10\n",
    "M = torch.triu(-float(\"inf\")*torch.ones(T,T),1)\n",
    "X = torch.randn(N,T,d)\n",
    "Y_, A_ = attn(X,X,X, attn_mask=M)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "546135ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "Y, A = self_attention(X.numpy(), M.numpy(),\n",
    "                      attn.in_proj_weight.detach().numpy().T, \n",
    "                      attn.out_proj.weight.detach().numpy().T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "2d96c92e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5.5253105e-07\n",
      "3.97839e-06\n"
     ]
    }
   ],
   "source": [
    "print(np.linalg.norm(A - A_.detach().numpy()))\n",
    "print(np.linalg.norm(Y - Y_.detach().numpy()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "756fc65a",
   "metadata": {},
   "source": [
    "## Multihead attention\n",
    "\n",
    "Practical implementations of attention use what is called _multihead_ attention, which simply means that we run the self-attention mechansism of different subsets of the $K$, $Q$, $V$ terms, then concatenate them together.  Formally, we'll partition these terms as\n",
    "\\begin{equation}\n",
    "K = \\begin{bmatrix} K_1 & K_2 & \\cdots & K_{\\mathrm{heads}} \\end{bmatrix}\n",
    "\\end{equation}\n",
    "(and similarly for $Q$ and $V$.\n",
    "\n",
    "Then will form the self attention outputs\n",
    "\\begin{equation}\n",
    "Y_i = \\mathrm{softmax}\\left(\\frac{K_iQ_i^T}{\\sqrt{d/\\mathrm{heads}}}\\right)V_i\n",
    "\\end{equation}\n",
    "and then form the final ouput\n",
    "\\begin{equation}\n",
    "Y = \\begin{bmatrix} Y_1 & Y_2 & \\cdots & Y_{\\mathrm{heads}} \\end{bmatrix} W_o.\n",
    "\\end{equation}\n",
    "\n",
    "The advantage of multi-head attention is that applying a single self-attention layer to a \"high dimensional\" hidden state (i.e., where $d$ is large) seems to waste a lot of the information contained in the hidden layers.  Recall, for intance, that the terms in the self attention matrix would be proportation to $k_t^T q_s$.  If $k_t$ and $q_s$ are high dimensional, then a lot of \"internal structure\" could be lost to result in ultimately just one weighting term.  By breaking this up and computing multiple differen attention matrices, each of which weights different dimensions of the $V$ term, we avoid this problem, and practically lead to better performance.  Note however that the \"right\" tradeoff between the number of heads and $d$ is still rather heuristic in nature."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "1acf3ad0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def multihead_attention(X, mask, heads, W_KQV, W_out):\n",
    "    N,T,d = X.shape\n",
    "    K,Q,V = np.split(X@W_KQV, 3, axis=-1)\n",
    "    K,Q,V = [a.reshape(N,T,heads,d//heads).swapaxes(1,2) for a in (K,Q,V)]\n",
    "    \n",
    "    attn = softmax(K@Q.swapaxes(-1,-2) / np.sqrt(d//heads) + mask)\n",
    "    return (attn@V).swapaxes(1,2).reshape(N,T,d) @ W_out, attn"
   ]
  },
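  {
   "cell_type": "markdown",
   "id": "a1c0de05",
   "metadata": {},
   "source": [
    "The `reshape` followed by `swapaxes(1,2)` in the function above is what realizes the column partition $K = \\begin{bmatrix} K_1 & \\cdots & K_{\\mathrm{heads}} \\end{bmatrix}$ as an extra batch dimension.  The sketch below checks this on a small random array; the sizes and the names `K_demo`, `K_heads`, and `K_blocks` are hypothetical and only for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1c0de06",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# small illustrative sizes (d_demo must be divisible by heads_demo)\n",
    "N_demo, T_demo, d_demo, heads_demo = 2, 6, 8, 4\n",
    "K_demo = np.random.randn(N_demo, T_demo, d_demo)\n",
    "\n",
    "# reshape + swapaxes, as in multihead_attention: shape (N, heads, T, d/heads)\n",
    "K_heads = K_demo.reshape(N_demo, T_demo, heads_demo, d_demo // heads_demo).swapaxes(1, 2)\n",
    "\n",
    "# explicit column partition K = [K_1 K_2 ... K_heads], each of shape (N, T, d/heads)\n",
    "K_blocks = np.split(K_demo, heads_demo, axis=-1)\n",
    "\n",
    "# head i of the reshaped array should equal the i-th column block\n",
    "print(all(np.array_equal(K_heads[:, i], K_blocks[i]) for i in range(heads_demo)))"
   ]
  },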
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "7c001698",
   "metadata": {},
   "outputs": [],
   "source": [
    "heads = 4\n",
    "attn = nn.MultiheadAttention(d, heads, bias=False, batch_first=True)\n",
    "Y_, A_ = attn(X,X,X, attn_mask=M)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "846e82b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "Y, A = multihead_attention(X.numpy(), M.numpy(), 4,\n",
    "                           attn.in_proj_weight.detach().numpy().T, \n",
    "                           attn.out_proj.weight.detach().numpy().T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "a8d14502",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([10, 100, 100])"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "A_.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "127ae789",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10, 4, 100, 100)"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "A.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "e1fe95a7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4.0823516e-06\n",
      "4.2045417e-07\n"
     ]
    }
   ],
   "source": [
    "print(np.linalg.norm(Y - Y_.detach().numpy()))\n",
    "print(np.linalg.norm(A.mean(1) - A_.detach().numpy()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63fe7db8",
   "metadata": {},
   "source": [
    "## Transformer Block\n",
    "\n",
    "Let's finally put all this together into a full transformer block.  Transformers simply amount to a self-attention block, with a residual layers and layer norm operation, followed by a two-layer feedforward network, with another residual layer and layer norm.  We can implement this in a few lines of code.  Note that in \"real\" implementations, the layer norm terms, etc, would actually have trainable scale/bias terms that add a bit more expressivity to the model.  This version we show will only be the same, for instance, at initialization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "c8daabf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def layer_norm(Z, eps):\n",
    "    return (Z - Z.mean(axis=-1, keepdims=True)) / np.sqrt(Z.var(axis=-1, keepdims=True) + eps)\n",
    "    \n",
    "def relu(Z):\n",
    "    return np.maximum(Z,0)\n",
    "\n",
    "def transformer(X, mask, heads, W_KQV, W_out, W_ff1, W_ff2, eps):\n",
    "    Z = layer_norm(multihead_attention(X, mask, heads, W_KQV, W_out)[0] + X, eps)\n",
    "    return layer_norm(Z + relu(Z@W_ff1)@W_ff2, eps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "9b3c6bd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "trans = nn.TransformerEncoderLayer(d, heads, dim_feedforward=128, dropout=0.0, batch_first=True)\n",
    "trans.linear1.bias.data.zero_()\n",
    "trans.linear2.bias.data.zero_();\n",
    "Y_ = trans(X, M)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "a47cc0dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "Y = transformer(X.numpy(), M.numpy(), heads,\n",
    "                trans.self_attn.in_proj_weight.detach().numpy().T, \n",
    "                trans.self_attn.out_proj.weight.detach().numpy().T,\n",
    "                trans.linear1.weight.detach().numpy().T,\n",
    "                trans.linear2.weight.detach().numpy().T,\n",
    "                trans.norm1.eps)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "676dfc0f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.7750326e-05\n"
     ]
    }
   ],
   "source": [
    "print(np.linalg.norm(Y - Y_.detach().numpy()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e94da64",
   "metadata": {},
   "source": [
    "## The question for \"efficient Transformers\"\n",
    "\n",
    "Since the Transformer was first proposed, there have been endless attempts made to make different \"efficient\" versions of the operation.  The key drawback of transformers, we have seen, is that they require forming a the $T \\times T$ attention matrix and multiplying by $V$ (an $O(T^2d)$ operation)\n",
    "\\begin{equation}\n",
    "\\mathrm{softmax}\\left(\\frac{KQ^T}{\\sqrt{d}}\\right)V\n",
    "\\end{equation}\n",
    "If $T$ is much larger than $d$ (e.g., the sequence is very long, then this operation is quite costly).\n",
    "\n",
    "There are essentially two approaches to making the approach more efficient: by attempting the represent the attention matrix\n",
    "\\begin{equation}\n",
    "A = \\mathrm{softmax}\\left(\\frac{KQ^T}{\\sqrt{d}}\\right)\n",
    "\\end{equation}\n",
    "either using _sparsity_ or using _low rank_ structure.  In general, of course, this matrix neither sparse nor low rank.  But we could simply dicate, for example, that we will only compute some subset of the attention weights, thereby decreasing the number of inner products we need to perform (this is the basis of the so-called \"Sparse Attention\" layer: similar approaches have been proposed a number of times, but [this](https://arxiv.org/abs/1904.10509) is one such example).  Alternatively, one could try to infer some kind of hard sparsity by e.g., triangle inequalities or other similar instances (because, remember, we are computing what amounts to a similarly metric between the $x$ terms at different times).\n",
    "\n",
    "Alternatively, we could try to represent $A$ in _low rank_ form instead.  To see why this could be appealing, consider the case where we don't have a softmax operation at all, but instead used the \"attention\" layer \n",
    "\\begin{equation}\n",
    "\\left(\\frac{KQ^T}{\\sqrt{d}}\\right)V\n",
    "\\end{equation}\n",
    "In this case, if $T \\gg d$, we could instead perform our multiplication in the order $K(Q^T V)$, which would only have complexity $O(Td^2)$, potentially much smaller.  And some papers infact advocate for this very thing, or alternatively try to find a low-rank representation of the actual attention weights, to similar effects.\n",
    "\n",
    "The thing to keep in mind with all these \"efficient\" alternatives (and if you have been reading the literation surrounding Transformers, you have likely seen a _ton_ of these), is whether they are actually more efficient, for an equivalent level of performance, once real execution speed in taken into account.  My best understanding of the current situation is that 1) explicit sparse self attention is indeed sometimes useful for models that want very long history, but that 2) most of the \"efficient\" transformer mechanisms that use low rank structure or inferred sparsity structure don't improve much in practice over traditional attention."
   ]
  }
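  ,
  {
   "cell_type": "markdown",
   "id": "a1c0de07",
   "metadata": {},
   "source": [
    "To make the reordering argument concrete, the sketch below checks numerically that, with no softmax (and no mask) in between, $(KQ^T)V$ and $K(Q^TV)$ give the same result.  The first ordering forms a $T \\times T$ matrix and the second only a $d \\times d$ one.  The sizes and the names `K_demo`, `Q_demo`, `V_demo`, `quadratic`, and `linear` are hypothetical, and this is only an illustration of the associativity argument, not an implementation of any particular published method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1c0de08",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# illustrative sizes with T much larger than d\n",
    "T_demo, d_demo = 500, 16\n",
    "K_demo = np.random.randn(T_demo, d_demo)\n",
    "Q_demo = np.random.randn(T_demo, d_demo)\n",
    "V_demo = np.random.randn(T_demo, d_demo)\n",
    "\n",
    "# forms the T x T matrix K Q^T first: O(T^2 d) work\n",
    "quadratic = (K_demo @ Q_demo.T / np.sqrt(d_demo)) @ V_demo\n",
    "\n",
    "# forms the d x d matrix Q^T V first: O(T d^2) work\n",
    "linear = K_demo @ (Q_demo.T @ V_demo) / np.sqrt(d_demo)\n",
    "\n",
    "# both orderings agree up to floating point error\n",
    "print(np.allclose(quadratic, linear))"
   ]
  }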
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
